{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fetSbKRrczhD"
      },
      "source": [
        "Copyright 2020 DeepMind Technologies Limited.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zuYtT7GpuVgQ"
      },
      "source": [
        "\u003cp align=\"center\"\u003e\n",
        "  \u003ch1 align=\"center\"\u003eTAPIR: Tracking Any Point with per-frame Initialization and temporal Refinement\u003c/h1\u003e\n",
        "  \u003cp align=\"center\"\u003e\n",
        "    \u003ca href=\"http://www.carldoersch.com/\"\u003eCarl Doersch\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://yangyi02.github.io/\"\u003eYi Yang\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://scholar.google.com/citations?user=Jvi_XPAAAAAJ\"\u003eMel Vecerik\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://scholar.google.com/citations?user=cnbENAEAAAAJ\"\u003eDilara Gokay\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://www.robots.ox.ac.uk/~ankush/\"\u003eAnkush Gupta\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"http://people.csail.mit.edu/yusuf/\"\u003eYusuf Aytar\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ\"\u003eJoao Carreira\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://www.robots.ox.ac.uk/~az/\"\u003eAndrew Zisserman\u003c/a\u003e\n",
        "  \u003c/p\u003e\n",
        "  \u003ch3 align=\"center\"\u003e\u003ca href=\"https://arxiv.org/abs/2306.08637\"\u003ePaper\u003c/a\u003e | \u003ca href=\"https://deepmind-tapir.github.io\"\u003eProject Page\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet\"\u003eGitHub\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet/tree/main#running-tapir-locally\"\u003eLive Demo\u003c/a\u003e \u003c/h3\u003e\n",
        "  \u003cdiv align=\"center\"\u003e\u003c/div\u003e\n",
        "\u003c/p\u003e\n",
        "\n",
        "\u003cp align=\"center\"\u003e\n",
        "  \u003ca href=\"\"\u003e\n",
        "    \u003cimg src=\"https://storage.googleapis.com/dm-tapnet/swaying_gif.gif\" alt=\"Logo\" width=\"50%\"\u003e\n",
        "  \u003c/a\u003e\n",
        "\u003c/p\u003e\n",
        "\n",
        ""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mCmDvfFvxnGB"
      },
      "outputs": [],
      "source": [
        "# @title Install Code and Dependencies {form-width: \"25%\"}\n",
        "!pip install git+https://github.com/google-deepmind/tapnet.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zyEo9-Kv78S7"
      },
      "outputs": [],
      "source": [
        "MODEL_TYPE = 'bootstapir'  # 'tapir' or 'bootstapir'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HaswJZMq9B3c"
      },
      "outputs": [],
      "source": [
        "# @title Download Model {form-width: \"25%\"}\n",
        "\n",
        "%mkdir tapnet/checkpoints\n",
        "\n",
        "if MODEL_TYPE == \"tapir\":\n",
        "  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy\n",
        "else:\n",
        "  !wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.npy\n",
        "\n",
        "%ls tapnet/checkpoints"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FxlHY242m-6Q"
      },
      "outputs": [],
      "source": [
        "# @title Imports {form-width: \"25%\"}\n",
        "%matplotlib widget\n",
        "from google.colab import output\n",
        "import jax\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import mediapy as media\n",
        "import numpy as np\n",
        "from tapnet.models import tapir_model\n",
        "from tapnet.utils import model_utils\n",
        "from tapnet.utils import transforms\n",
        "from tapnet.utils import viz_utils\n",
        "\n",
        "output.enable_custom_widget_manager()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7rfy2yobnHqw"
      },
      "outputs": [],
      "source": [
        "# @title Load Checkpoint {form-width: \"25%\"}\n",
        "\n",
        "if MODEL_TYPE == 'tapir':\n",
        "  checkpoint_path = 'tapnet/checkpoints/tapir_checkpoint_panning.npy'\n",
        "else:\n",
        "  checkpoint_path = 'tapnet/checkpoints/bootstapir_checkpoint_v2.npy'\n",
        "ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()\n",
        "params, state = ckpt_state['params'], ckpt_state['state']\n",
        "\n",
        "kwargs = dict(bilinear_interp_with_depthwise_conv=False, pyramid_level=0)\n",
        "if MODEL_TYPE == 'bootstapir':\n",
        "  kwargs.update(\n",
        "      dict(pyramid_level=1, extra_convs=True, softmax_temperature=10.0)\n",
        "  )\n",
        "tapir = tapir_model.ParameterizedTAPIR(params, state, tapir_kwargs=kwargs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y1HkuvyT6SJF"
      },
      "outputs": [],
      "source": [
        "# @title Load an Exemplar Video {form-width: \"25%\"}\n",
        "\n",
        "%mkdir tapnet/examplar_videos\n",
        "\n",
        "!wget -P tapnet/examplar_videos http://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n",
        "\n",
        "video = media.read_video(\"tapnet/examplar_videos/horsejump-high.mp4\")\n",
        "media.show_video(video, fps=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ogRTRVgfSq0W"
      },
      "outputs": [],
      "source": [
        "# @title Utility Functions {form-width: \"25%\"}\n",
        "\n",
        "\n",
        "def inference(frames, query_points):\n",
        "  \"\"\"Inference on one video.\n",
        "\n",
        "  Args:\n",
        "    frames: [num_frames, height, width, 3], [0, 255], np.uint8\n",
        "    query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]\n",
        "\n",
        "  Returns:\n",
        "    tracks: [num_points, 3], [-1, 1], [t, y, x]\n",
        "    visibles: [num_points, num_frames], bool\n",
        "  \"\"\"\n",
        "  # Preprocess video to match model inputs format\n",
        "  frames = model_utils.preprocess_frames(frames)\n",
        "  query_points = query_points.astype(np.float32)\n",
        "  frames, query_points = frames[None], query_points[None]  # Add batch dimension\n",
        "\n",
        "  outputs = tapir(\n",
        "      video=frames,\n",
        "      query_points=query_points,\n",
        "      is_training=False,\n",
        "      query_chunk_size=32,\n",
        "  )\n",
        "  tracks, occlusions, expected_dist = (\n",
        "      outputs['tracks'],\n",
        "      outputs['occlusion'],\n",
        "      outputs['expected_dist'],\n",
        "  )\n",
        "\n",
        "  # Binarize occlusions\n",
        "  visibles = model_utils.postprocess_occlusions(occlusions, expected_dist)\n",
        "  return tracks[0], visibles[0]\n",
        "\n",
        "\n",
        "inference = jax.jit(inference)\n",
        "\n",
        "\n",
        "def sample_random_points(frame_max_idx, height, width, num_points):\n",
        "  \"\"\"Sample random points with (time, height, width) order.\"\"\"\n",
        "  y = np.random.randint(0, height, (num_points, 1))\n",
        "  x = np.random.randint(0, width, (num_points, 1))\n",
        "  t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))\n",
        "  points = np.concatenate((t, y, x), axis=-1).astype(\n",
        "      np.int32\n",
        "  )  # [num_points, 3]\n",
        "  return points"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rIASK5A2Rj0X"
      },
      "outputs": [],
      "source": [
        "# @title Predict Sparse Point Tracks {form-width: \"25%\"}\n",
        "\n",
        "resize_height = 256  # @param {type: \"integer\"}\n",
        "resize_width = 256  # @param {type: \"integer\"}\n",
        "num_points = 20  # @param {type: \"integer\"}\n",
        "\n",
        "frames = media.resize_video(video, (resize_height, resize_width))\n",
        "query_points = sample_random_points(\n",
        "    0, frames.shape[1], frames.shape[2], num_points\n",
        ")\n",
        "tracks, visibles = inference(frames, query_points)\n",
        "tracks = np.array(tracks)\n",
        "visibles = np.array(visibles)\n",
        "\n",
        "# Visualize sparse point tracks\n",
        "height, width = video.shape[1:3]\n",
        "tracks = transforms.convert_grid_coordinates(\n",
        "    tracks, (resize_width, resize_height), (width, height)\n",
        ")\n",
        "video_viz = viz_utils.paint_point_track(video, tracks, visibles)\n",
        "media.show_video(video_viz, fps=10)"
      ]
    },
    {
      "metadata": {
        "id": "-2J3DHHtfVF5"
      },
      "cell_type": "code",
      "source": [
        "# @title Download TAPVid-DAVSIS Dataset {form-width: \"25%\"}\n",
        "!wget https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip\n",
        "!unzip tapvid_davis.zip"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "BTuJ_8LrfVF5"
      },
      "cell_type": "code",
      "source": [
        "# @title DAVIS First Eval on 256x256 Resolution {form-width: \"25%\"}\n",
        "%%time\n",
        "\n",
        "from tapnet import evaluation_datasets\n",
        "\n",
        "davis_dataset = evaluation_datasets.create_davis_dataset(\n",
        "    'tapvid_davis/tapvid_davis.pkl', query_mode='first'\n",
        ")\n",
        "\n",
        "summed_scalars = None\n",
        "for sample_idx, sample in enumerate(davis_dataset):\n",
        "  sample = sample['davis']\n",
        "  frames = np.round((sample['video'][0] + 1) / 2 * 255).astype(np.uint8)\n",
        "  query_points = sample['query_points'][0]\n",
        "\n",
        "  tracks, visibles = inference(frames, query_points)\n",
        "  tracks = np.array(tracks)\n",
        "  visibles = np.array(visibles)\n",
        "  occluded = ~visibles\n",
        "\n",
        "  query_points = sample['query_points'][0]\n",
        "\n",
        "  scalars = evaluation_datasets.compute_tapvid_metrics(\n",
        "      query_points[None],\n",
        "      sample['occluded'],\n",
        "      sample['target_points'],\n",
        "      occluded[None],\n",
        "      tracks[None],\n",
        "      query_mode='first',\n",
        "  )\n",
        "  scalars = jax.tree.map(lambda x: np.array(np.sum(x, axis=0)), scalars)\n",
        "  print(sample_idx, scalars)\n",
        "\n",
        "  if summed_scalars is None:\n",
        "    summed_scalars = scalars\n",
        "  else:\n",
        "    summed_scalars = jax.tree.map(np.add, summed_scalars, scalars)\n",
        "\n",
        "  num_samples = sample_idx + 1\n",
        "  mean_scalars = jax.tree.map(lambda x: x / num_samples, summed_scalars)\n",
        "  print(mean_scalars)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QpTDg7oQJDC3"
      },
      "outputs": [],
      "source": [
        "# @title Efficient Chunked Point Track Prediction {form-width: \"25%\"}\n",
        "\n",
        "resize_height = 256  # @param {type: \"integer\"}\n",
        "resize_width = 256  # @param {type: \"integer\"}\n",
        "num_points = 4096  # @param {type: \"integer\"}\n",
        "\n",
        "frames = media.resize_video(video, (resize_height, resize_width))\n",
        "frames = model_utils.preprocess_frames(frames[None])\n",
        "feature_grids = tapir.get_feature_grids(frames, is_training=False)\n",
        "query_points = sample_random_points(\n",
        "    0, frames.shape[2], frames.shape[3], num_points\n",
        ")\n",
        "chunk_size = 32\n",
        "\n",
        "\n",
        "def chunk_inference(query_points):\n",
        "  query_points = query_points.astype(np.float32)[None]\n",
        "\n",
        "  outputs = tapir(\n",
        "      video=frames,\n",
        "      query_points=query_points,\n",
        "      is_training=False,\n",
        "      query_chunk_size=chunk_size,\n",
        "      feature_grids=feature_grids,\n",
        "  )\n",
        "  tracks, occlusions, expected_dist = (\n",
        "      outputs[\"tracks\"],\n",
        "      outputs[\"occlusion\"],\n",
        "      outputs[\"expected_dist\"],\n",
        "  )\n",
        "\n",
        "  # Binarize occlusions\n",
        "  visibles = model_utils.postprocess_occlusions(occlusions, expected_dist)\n",
        "  return tracks[0], visibles[0]\n",
        "\n",
        "\n",
        "chunk_inference = jax.jit(chunk_inference)\n",
        "\n",
        "all_tracks = []\n",
        "all_visibles = []\n",
        "for chunk in range(0, query_points.shape[0], chunk_size):\n",
        "  tracks, visibles = chunk_inference(query_points[chunk : chunk + chunk_size])\n",
        "  all_tracks.append(np.array(tracks))\n",
        "  all_visibles.append(np.array(visibles))\n",
        "\n",
        "tracks = np.concatenate(all_tracks, axis=0)\n",
        "visibles = np.concatenate(all_visibles, axis=0)\n",
        "\n",
        "# Visualize sparse point tracks\n",
        "height, width = video.shape[1:3]\n",
        "tracks = transforms.convert_grid_coordinates(\n",
        "    tracks, (resize_width, resize_height), (width, height)\n",
        ")\n",
        "video_viz = viz_utils.paint_point_track(video, tracks, visibles)\n",
        "media.show_video(video_viz, fps=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pY4hi_uTv6K4"
      },
      "outputs": [],
      "source": [
        "# @title Select Any Points at Any Frame {form-width: \"25%\"}\n",
        "\n",
        "select_frame = 0  # @param {type:\"slider\", min:0, max:49, step:1}\n",
        "\n",
        "# Generate a colormap with 20 points, no need to change unless select more than 20 points\n",
        "colormap = viz_utils.get_colors(20)\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(10, 5))\n",
        "ax.imshow(video[select_frame])\n",
        "ax.axis('off')\n",
        "ax.set_title(\n",
        "    'You can select more than 1 point. After selecting enough points, run the'\n",
        "    ' next cell.'\n",
        ")\n",
        "\n",
        "select_points = []\n",
        "\n",
        "\n",
        "# Event handler for mouse clicks\n",
        "def on_click(event):\n",
        "  if event.button == 1 and event.inaxes == ax:  # Left mouse button clicked\n",
        "    x, y = int(np.round(event.xdata)), int(np.round(event.ydata))\n",
        "\n",
        "    select_points.append(np.array([x, y]))\n",
        "\n",
        "    color = colormap[len(select_points) - 1]\n",
        "    color = tuple(np.array(color) / 255.0)\n",
        "    ax.plot(x, y, 'o', color=color, markersize=5)\n",
        "    plt.draw()\n",
        "\n",
        "\n",
        "fig.canvas.mpl_connect('button_press_event', on_click)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_-JUf0pfwFN2"
      },
      "outputs": [],
      "source": [
        "# @title Predict Point Tracks for the Selected Points {form-width: \"25%\"}\n",
        "\n",
        "resize_height = 256  # @param {type: \"integer\"}\n",
        "resize_width = 256  # @param {type: \"integer\"}\n",
        "\n",
        "\n",
        "def convert_select_points_to_query_points(frame, points):\n",
        "  \"\"\"Convert select points to query points.\n",
        "\n",
        "  Args:\n",
        "    points: [num_points, 2], in [x, y]\n",
        "\n",
        "  Returns:\n",
        "    query_points: [num_points, 3], in [t, y, x]\n",
        "  \"\"\"\n",
        "  points = np.stack(points)\n",
        "  query_points = np.zeros(shape=(points.shape[0], 3), dtype=np.float32)\n",
        "  query_points[:, 0] = frame\n",
        "  query_points[:, 1] = points[:, 1]\n",
        "  query_points[:, 2] = points[:, 0]\n",
        "  return query_points\n",
        "\n",
        "\n",
        "frames = media.resize_video(video, (resize_height, resize_width))\n",
        "query_points = convert_select_points_to_query_points(\n",
        "    select_frame, select_points\n",
        ")\n",
        "height, width = video.shape[1:3]\n",
        "query_points = transforms.convert_grid_coordinates(\n",
        "    query_points,\n",
        "    (1, height, width),\n",
        "    (1, resize_height, resize_width),\n",
        "    coordinate_format='tyx',\n",
        ")\n",
        "tracks, visibles = inference(frames, query_points)\n",
        "\n",
        "# Visualize sparse point tracks\n",
        "tracks = transforms.convert_grid_coordinates(\n",
        "    tracks, (resize_width, resize_height), (width, height)\n",
        ")\n",
        "video_viz = viz_utils.paint_point_track(video, tracks, visibles, colormap)\n",
        "media.show_video(video_viz, fps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8orfQRoaRJit"
      },
      "source": [
        "That's it!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {},
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
